from yacs.config import CfgNode as CN
import logging

# -----------------------------------------------------------------------------
# Config definition
# -----------------------------------------------------------------------------
# Root of the default configuration tree; each section is attached to it as a
# child CfgNode further down in this module.
_C = CN({
    # Project / experiment name.
    "NAME": "",
    # Directory in which checkpoints are saved.
    "SAVE_DIR": "",
    # Device string, "cuda" by default (presumably handed to torch — confirm).
    "DEVICE": "cuda",
})

# Model-related options. GLOBAL and LOCAL share the same layout (a backbone plus
# optional spatial and channel attention); only their default values differ.
_C.MODEL = CN({
    # Presumably the size of the final output embedding — confirm.
    "EMBED_SIZE": 1024,
    "ATTRIBUTE": {
        # Presumably the size of the attribute embedding — confirm.
        "EMBED_SIZE": 512,
    },
    # Global branch (presumably operating on the whole image — confirm).
    "GLOBAL": {
        "BACKBONE": {
            "NAME": "resnet50",
            "EMBED_SIZE": 1024,
        },
        "ATTENTION": {
            "SPATIAL": {
                "ENABLE": True,
                "COMMON_EMBED_SIZE": 512,
            },
            "CHANNEL": {
                "ENABLE": True,
                # Presumably the channel-attention bottleneck ratio — confirm.
                "REDUCTION_RATE": 2,
            },
        },
    },
    # Local branch; switched off by default (ENABLE is False).
    "LOCAL": {
        "ENABLE": False,
        "BACKBONE": {
            "NAME": "resnet34",
            "EMBED_SIZE": 256,
        },
        "ATTENTION": {
            "SPATIAL": {
                "ENABLE": True,
                "COMMON_EMBED_SIZE": 256,
            },
            "CHANNEL": {
                "ENABLE": True,
                "REDUCTION_RATE": 2,
            },
        },
    },
})

# Input / preprocessing options.
_C.INPUT = CN({
    # Per-channel normalization statistics. The values are the widely used
    # ImageNet mean/std (presumably RGB order on a [0, 1] scale — confirm).
    "PIXEL_MEAN": [0.485, 0.456, 0.406],
    "PIXEL_STD": [0.229, 0.224, 0.225],
    # Presumably the input resolution, in pixels, of the global and local
    # branches respectively — confirm.
    "GLOBAL_SIZE": 224,
    "LOCAL_SIZE": 112,
    # NOTE(review): semantics not visible in this file — confirm against the
    # code that reads INPUT.THRESHOLD.
    "THRESHOLD": 0.4,
})

# Dataset and data-loading options.
_C.DATA = CN({
    "NUM_WORKERS": 4,
    # Presumably the number of training triplets sampled — confirm.
    "NUM_TRIPLETS": 100000,
    "TRAIN_BATCHSIZE": 16,
    "TEST_BATCHSIZE": 64,
    "BASE_PATH": "data",
    "DATASET": "",
    # -1 is a placeholder (presumably set by the dataset config — confirm).
    "NUM_ATTRIBUTES": -1,
})

# Attribute definitions are dataset-specific and come from the dataset's own
# config file; new_allowed=True lets a merged config add keys under this node.
_C.DATA.ATTRIBUTES = CN(new_allowed=True)

# Split list files.
_C.DATA.PATH_FILE = CN({"TRAIN": "", "VALID": "", "TEST": ""})

# Ground-truth files. The node itself accepts new keys on merge; its QUERY and
# CANDIDATE children do not.
_C.DATA.GROUNDTRUTH = CN(new_allowed=True)
_C.DATA.GROUNDTRUTH.QUERY = CN({"TEST": "", "VALID": ""})
_C.DATA.GROUNDTRUTH.CANDIDATE = CN({"TEST": "", "VALID": ""})

# Optimizer, schedule and loss options.
_C.SOLVER = CN({
    # Presumably "evaluate every EVAL_STEPS epochs" — confirm units.
    "EVAL_STEPS": 1,
    "OPTIMIZER_NAME": "Adam",
    # Presumably a step LR schedule: multiply by DECAY_RATE every STEP_SIZE
    # epochs — confirm.
    "STEP_SIZE": 3,
    "DECAY_RATE": 0.9,
    "EPOCHS": 50,
    # Presumably the logging interval in iterations — confirm.
    "LOG_PERIOD": 800,
    "BASE_LR": 1e-4,
    # Presumably a lower LR for pretrained parameters — confirm.
    "BASE_LR_SLOW": 1e-5,
    # Presumably weights on the alignment / local / global loss terms — confirm.
    "ALIGN_WEIGHT": 1.0,
    "LOCAL_WEIGHT": 1.0,
    "GLOBAL_WEIGHT": 1.0,
    # Presumably the triplet-loss margin — confirm.
    "MARGIN": 0.2,
    # NOTE(review): semantics not visible in this file — confirm.
    "BETA": 0.6,
})

# Logging options.
_C.LOGGER = CN({
    # Stored as the integer logging.INFO. yacs rejects merged values whose type
    # differs from the default, so overrides should be integers (e.g. 10 for
    # DEBUG) rather than level names.
    "LEVEL": logging.INFO,
    # Output stream name (presumably "stdout" or "stderr" — confirm).
    "STREAM": "stdout",
})